{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n",
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "import cv2\n",
    "import torch\n",
    "import numpy as np\n",
    "from torch.utils.data import Dataset,DataLoader\n",
    "from torchvision import transforms,models\n",
    "\n",
    "from keras.datasets import mnist\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 加载数据\n",
    "可视化数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each split returned by mnist.load_data() is an (images, labels) pair.\n",
    "train, test = mnist.load_data()\n",
    "(train_x, train_y), (test_x, test_y) = train, test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAATwAAAChCAYAAABaigMvAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAD5xJREFUeJzt3XuUVNWVBvBvd/MSbB1asCWAL2zNEB8kEmACE83yhTIZ1LiM5MUYFI26ouYlY+JKVkwyOBNJRnGYQGLsjDFoNEYcMQokGXygAoYgKEKLKHSgaR4REkT6seePup3Vp3ZV1+veW3XrfL+1WF37cKruVjabW/dWnSOqCiIiH9SUOwEioriw4RGRN9jwiMgbbHhE5A02PCLyBhseEXmDDY+IvMGGR0TeKKnhichkEXldRJpFZFZYSRF1Y41RmKTYb1qISC2AjQDOA7ANwEoA01T11WzP6Sf9dQAGFXU8Srb92LtLVYcW8pxCa4z15a9866tPCccYB6BZVTcDgIgsBDAVQNaGNwCDMF7OKeGQlFRL9eG3inhaQTXG+vJXvvVVylva4QC29oi3BWMOEZkpIqtEZFU73ivhcOShnDXG+qJCRH7TQlXnq+pYVR3bF/2jPhx5hvVFhSil4bUAGNkjHhGMEYWFNUahKqXhrQTQKCIniEg/AFcAWBROWkQAWGMUsqJvWqhqh4jcAOApALUA7lXV9aFlRt5jjVHYSrlLC1VdDGBxSLkQGawxChO/aUFE3ijpDI+Iotc8Z4IZ++6UB514wXWXmjl9lq2OLKek4hkeEXmDDY+IvMGGR0TeYMMjIm/wpgVRhTlwyXgnnj91gZnT0j7YiXeMs1+rG7Es3LyqAc/wiMgbbHhE5A02PCLyBq/hEZVR7VH1ZuyHc+524suWXG/mnHL9H514pL5k5hS3lnl14xkeEXmDDY+IvMGGR0TeKOkanohsAbAfQCeADlUdG0ZSRN1YYxSmMG5afExVd4XwOkTZVG2NNX/5FDPW1vmcE4+e3WbmdLQfiiynasa3tETkjVIbngJ4WkRWi8jMTBO4jR6VqNcaY31RIUp9SztJVVtE5GgAS0Rkg6ou7zlBVecDmA8AR0g9PxpEheq1xlhfVIhS97RoCX7uFJFHkdopfnnvz/JETa0T9mkYaqYcGnWMGWv+dL+cL/3MlDlOPKLP4WbOG+1/ceKp875m5gyf/XzOY5VbtdfYQ9N+aMYufeKLTty4+cW40ql6Rb+lFZFBIlLX/RjA+QDWhZUYEWuMwlbKGV4DgEdFpPt1HlDV34SSFVEKa4xCVcq+tJsBnBFiLkQO1hiFjR9LISJvcLWUItQOdW9AtHyq0czRj+114tUfvj+0429sd2+ILN13tJnTfPA0Jx755F4zpyu0jChf6auj1Ne2mzlHbKw1YxQOnuERkTfY8IjIG2x4ROQNXsMrwobbTnTi1z9xd5aZhXut3b2m07T7I2bO6m+c6cT9n1yZzyuXkhaFpPUTdrGAdMMffduJO6JKxkM8wyMib7DhEZE32PCIyBtseETkDd60yOHNhaebsRcmzkkbGWDmvNN10Ik/+qOvmjlHvdppxg5rddd0k+fWmDn9kc9NCqpEp0zf4MR7OvuaOR1bt8WVjnd4hkdE3mDDIyJv5Gx4InKviOwUkXU9xupFZImIbAp+Do42TapmrDGKSz7X8O4DMBfAz3qMzQKwTFVni8isIL4l/PTK73OjXzJjg2vsNbt06w7VOfHI71T+6sJldB+qscZS6/g5Tq37kxPPfPUzZs5gbIospXQHLhlvxrZflntHtM533GuPDc/Zc6cjH0hbqVnLvwJ/zjO8YP+APWnDUwE0BY+bAFwccl7kEdYYxaXYa3gNqro9eLwDqZVpicLEGqPQlXzTQlUVqa30MuI2elSq3mqM9UWFKLbhtYrIMAAIfu7MNlFV56vqWFUd2xf9izwceSivGmN9
USGK/eDxIgDTAcwOfj4WWkYV5v4NHzZjt0xcn/N5Vz3q7hk9Ci+ElpMnEl9jtaOON2O3HPWIE//yv8/J8MxwblrUDLA31zbMPdWJmy+cZ+Y8fuAIJ978nl1Re2nb+5347ikPmTmf7fiKE9c9WP6/A/l8LOUXAFYAOEVEtonIDKSK8DwR2QTg3CAmKgprjOKS8wxPVadl+a1M/zQRFYw1RnHhNy2IyBtcPCCHw35fZwcnuuF7aneeGrHMLgxAlG5gW0h1UmN3Otv6wCgz1jx+vhOfPvcGM+fY/3QXrOg6cCDDAd0PUF9xpV0cY9btP3fin/zWfsi5s60tw2tHh2d4ROQNNjwi8gYbHhF5gw2PiLzBmxYhOKj2wnN+WydSNTvQOCTnnCN/v9mMFXMbo/lndmXun475qRn76E1fcOIRD68wc7qKWNVkyMPrzNjQ2/a5A0cebp/ImxZERNFgwyMib7DhEZE3eA2PKCIHGqL769XnhOOceN6E+82cW796jRk7/JEXzVgYuvbvN2MLd09w4h3nHmPmDG1+M5J8suEZHhF5gw2PiLzBhkdE3ih2m8ZviUiLiKwJfl0UbZpUzVhjFJdit2kEgB+o6vdDz4h8dB+qsMZqD+X+AG/HSe8zY5LHh3GbZ7jPmzTgr2ZO3ZOvmLGunK8cnfY6u21l3IrdppEoNKwxiksp1/BuEJG1wduRrLvCc1cpKkHOGmN9USGKbXjzAIwCMAbAdgB3ZpvIXaWoSHnVGOuLClFUw1PVVlXtVNUuAAsAjAs3LfIda4yiUNRHwUVkWI9d4S8BYJdKqBLve/xtM7biK+5y2mf0s/9u1JzubmPXtXZDuIlVuWqoscFPbTRjz9zu/pVrvtYuzd5oFzAxjnnBXVNl4Of7mTnvfNyuoBLVVonS1x7/uAG7nfilPxe+CkvYcja8YAu9swEMEZFtAL4J4GwRGYPUbvBbANjvsBDliTVGcSl2m8afRJALeYo1RnHhNy2IyBtcLSWHjq3bzNifOwc68UCxa9T+668XOvEf3z3OzMnkrifcLxQ03vmGmdPZujOv16Ly6txtP1r49L5Tnfh//vHHZs7tfd1VRrT9kJkzYNdBJ27PsOp2V4x/u7fcdqYZO2vQXCde/viJZk5HZBllxjM8IvIGGx4ReYMNj4i8wYZHRN7gTYsi3PHGZCe+8LRfmjn/0L8zLbbb8WVy7afdC71XTjrHzHn7390vHRz265fyem0qv9/8aKITf/O21WbOxh+7NzYap79sX+iFtU74geWfN1PmfXuBGbt6wlVOXPtu7nOeYc/bGyL7jnVbx4or7aI2/3zjzU48cEc0y8sXgmd4ROQNNjwi8gYbHhF5Q1Tj+0LvEVKv48Vek0qcGvcL31u+bRfyqF/v/n9t+5Bd7fXqyUvN2Jfqcy8ycPLia9346pU5n1NuS/Xh1ao6NspjJLG+9j7RaMaWnOEu/Dzmf280c0bP3uHEXW27zZxdl9vFAw4OSavDDIsQd/Z143dPsusMnv337sIIb996spnT57f2+mRU8q0vnuERkTfY8IjIG2x4ROSNfLZpHCkivxORV0VkvYjcGIzXi8gSEdkU/My6rwVRNqwvilPOmxYiMgzAMFV9WUTqAKwGcDGAfwGwR1Vni8gsAINV9ZbeXiuJF5Wj1OfE483YJxc/68TT6lrNnO/tOs2JV5w50MzRjrjXoehdtovKvteX9Lf7cLx+l3uzYf2Ue8ycp9+td+Kbn7nCzOnX0teMpTv7gjVm7L+GP+fE0948z8zZ+/Vjnbjm//6Q81hRCu2mhapuV9WXg8f7AbwGYDiAqQCagmlNSBUpUUFYXxSngr5aJiLHA/gggBcBNPTYc2AHgIYsz5kJYCYADIA9EyHqxvqiqOV900JEDgfwCICbVHVfz9/T1PvijO+NuY0e5YP1RXHI6wxPRPoiVYw/V9VfBcOt3TtLBddhuAxvgTo2bzFjdzRd7sSTr/sPM+fWIa848cdrP5LhxSvr
Gl5vfK4vfc9+qPfka9wPkk+54DozZ8tl7ieGp421X8y/7tznzdhVzZ904qXPn2HmnPWsew1x0K9WmTk1XfaDzkmQz11aQWpDlddUdU6P31oEYHrweDqAx8JPj6od64vilM8Z3kQAnwXwioh039K5FcBsAA+JyAwAbwG4PMvziXrD+qLY5LNN47PI+I07AECyPgNAFYf1RXHiNy2IyBtc8bjCjPg390Lzg58ZbeZc+3f5rZ5M1aHfU/amwclPufHqDOcuMzApw6u1ONFJaXG14xkeEXmDDY+IvMGGR0Te4DW8ClN70glOfGL/3CsgE1F+eIZHRN5gwyMib7DhEZE32PCIyBu8aVFhNtx4tBOff9hfzZw5e97vDnR2RpkSUdXgGR4ReYMNj4i8UcquZd8SkRYRWRP8uij6dKnasL4oTvlcw+sA8OWeu0qJyJLg936gqt+PLj3/DFmV9m/QpXbOQ3PPdZ/TsSLCjCLH+qLY5LMe3nYA24PH+0Wke1cpopKxvihOBV3DS9tVCgBuEJG1InJvto2SRWSmiKwSkVXtsOv3E3VjfVHUStm1bB6AUQDGIPUv9J2ZnsddpSgfrC+KQ14NL9OuUqraqqqdqtoFYAGAcdGlSdWM9UVxyXkNL9uuUt1b6AXhJQDWRZOiXwY3uTcg/qnpTDNnCBJ9k8LB+qI4lbJr2TQRGYPUBslbAFwTSYZU7VhfFJtSdi1bHH465BvWF8WJ37QgIm+w4RGRN9jwiMgbbHhE5A02PCLyBhseEXlDVDW+g4m0AXgLwBAAu2I7cHiSmHel5Hycqg6N8gCsr7KolJzzqq9YG97fDiqySlXHxn7gEiUx7yTmXKqk/jcnMe+k5cy3tETkDTY8IvJGuRre/DIdt1RJzDuJOZcqqf/NScw7UTmX5RoeEVE58C0tEXmDDY+IvBF7wxORySLyuog0i8isuI+fj2APhZ0isq7HWL2ILBGRTcHPjHsslEsv2x1WdN5hS0J9AcmrsWqpr1gbnojUArgHwIUARiO1yOPoOHPI030AJqeNzQKwTFUbASwL4krSvd3haAATAFwf/L+t9LxDk6D6ApJXY1VRX3Gf4Y0D0Kyqm1X1EICFAKbGnENOqrocwJ604akAmoLHTQAujjWpHFR1u6q+HDzeD6B7u8OKzjtkiagvIHk1Vi31FXfDGw5ga494G5KzB2lDjz0WdgBoKGcyvUnb7jAxeYcgyfUFJOTPKsn1xZsWRdDUZ3kq8vM8GbY7/JtKzptclfpnlfT6irvhtQAY2SMeEYwlQauIDANSO2oB2FnmfIxM2x0iAXmHKMn1BVT4n1U11FfcDW8lgEYROUFE+gG4AsCimHMo1iIA04PH0wE8VsZcjGzbHaLC8w5ZkusLqOA/q6qpL1WN9ReAiwBsBPAGgK/Hffw8c/wFUrvdtyN1HWgGgKOQugu1CcBSAPXlzjMt50lIvZ1YC2BN8OuiSs/bx/pKYo1VS33xq2VE5A3etCAib7DhEZE32PCIyBtseETkDTY8IvIGGx4ReYMNj4i88f+0Mkgl0ab90wAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 360x360 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Show one training image next to one test image as a quick sanity check.\n",
    "fig, (train_ax, test_ax) = plt.subplots(1, 2, figsize=(5, 5))\n",
    "train_ax.imshow(train_x[100])\n",
    "test_ax.imshow(test_x[100])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 编写dataloader\n",
    "数据生成器，用于训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DataGen(Dataset):\n",
    "    def __init__(self, data, label, transforms=None):\n",
    "        self.data = data\n",
    "        self.label = label\n",
    "        self.transforms = transforms\n",
    "        \n",
    "    def __getitem__(self, index):\n",
    "        img = self.data[index]\n",
    "        img = np.expand_dims(img, -1)\n",
    "        tag = self.label[index]\n",
    "        if self.transforms:\n",
    "            img = self.transforms(img)\n",
    "        return img, tag\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.data.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ToTensor turns an H x W x C uint8 array into a C x H x W float tensor scaled to [0, 1].\n",
    "to_tensor = transforms.Compose([transforms.ToTensor()])\n",
    "\n",
    "# Training batches are reshuffled every epoch; the trailing partial batch is dropped.\n",
    "traingen = DataGen(train_x, train_y, transforms=to_tensor)\n",
    "trainloader = DataLoader(dataset=traingen,\n",
    "                         batch_size=32,\n",
    "                         pin_memory=False,\n",
    "                         drop_last=True,\n",
    "                         shuffle=True,\n",
    "                         num_workers=10)\n",
    "\n",
    "# Evaluation must see every test sample exactly once and in a fixed order:\n",
    "# drop_last=True would silently exclude the final partial batch (16 of MNIST's\n",
    "# 10,000 test images at batch_size=32), and shuffle=True would change which\n",
    "# images are excluded on every pass, making reported accuracies incomparable.\n",
    "testgen = DataGen(test_x, test_y, transforms=to_tensor)\n",
    "testloader = DataLoader(dataset=testgen,\n",
    "                        batch_size=32,\n",
    "                        pin_memory=False,\n",
    "                        drop_last=False,\n",
    "                        shuffle=False,\n",
    "                        num_workers=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 创建模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleModel(torch.nn.Module):\n",
    "    def __init__(self, categroies):\n",
    "        super(SimpleModel, self).__init__()\n",
    "        self.layer1 = torch.nn.Sequential(torch.nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, stride=1, padding=1, bias=False),\n",
    "                                          torch.nn.BatchNorm2d(num_features=8),\n",
    "                                          torch.nn.ReLU(),\n",
    "                                          torch.nn.MaxPool2d(kernel_size=2, stride=2))\n",
    "        self.layer2 = torch.nn.Sequential(torch.nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1, bias=False),\n",
    "                                          torch.nn.BatchNorm2d(num_features=16),\n",
    "                                          torch.nn.ReLU(),\n",
    "                                          torch.nn.MaxPool2d(kernel_size=2, stride=2))\n",
    "        self.fc = torch.nn.Linear(in_features=7*7*16, out_features=categroies)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.layer1(x)\n",
    "        x = self.layer2(x)\n",
    "        x = x.view(x.shape[0],-1)\n",
    "        x = self.fc(x)\n",
    "        return x\n",
    "    \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CrossEntropyLoss applies softmax internally, so the model's final layer\n",
    "# outputs raw scores and needs no softmax activation of its own.\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "model = SimpleModel(10)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 开始训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/10, steps:1, loss:2.3714566230773926\n",
      "1/10, steps:101, loss:0.5173688530921936\n",
      "1/10, steps:201, loss:0.2653545141220093\n",
      "1/10, steps:301, loss:0.15274108946323395\n",
      "1/10, steps:401, loss:0.22296205163002014\n",
      "1/10, steps:501, loss:0.1274479329586029\n",
      "1/10, steps:601, loss:0.04120491072535515\n",
      "1/10, steps:701, loss:0.1828516572713852\n",
      "1/10, steps:801, loss:0.06340346485376358\n",
      "1/10, steps:901, loss:0.0825238972902298\n",
      "1/10, steps:1001, loss:0.02459488995373249\n",
      "1/10, steps:1101, loss:0.11977686733007431\n",
      "1/10, steps:1201, loss:0.04187004640698433\n",
      "1/10, steps:1301, loss:0.038930751383304596\n",
      "1/10, steps:1401, loss:0.3668687343597412\n",
      "1/10, steps:1501, loss:0.02016628347337246\n",
      "1/10, steps:1601, loss:0.013538673520088196\n",
      "1/10, steps:1701, loss:0.005016855895519257\n",
      "1/10, steps:1801, loss:0.1457785665988922\n",
      "Average Accuracy is 0.9800681089743589\n",
      "2/10, steps:1, loss:0.1934467852115631\n",
      "2/10, steps:101, loss:0.17865468561649323\n",
      "2/10, steps:201, loss:0.13472557067871094\n",
      "2/10, steps:301, loss:0.08578787744045258\n",
      "2/10, steps:401, loss:0.02144831418991089\n",
      "2/10, steps:501, loss:0.005850791931152344\n",
      "2/10, steps:601, loss:0.04301856458187103\n",
      "2/10, steps:701, loss:0.11978933960199356\n",
      "2/10, steps:801, loss:0.04601867496967316\n",
      "2/10, steps:901, loss:0.02841944247484207\n",
      "2/10, steps:1001, loss:0.03699500858783722\n",
      "2/10, steps:1101, loss:0.13282698392868042\n",
      "2/10, steps:1201, loss:0.05713712424039841\n",
      "2/10, steps:1301, loss:0.005625538527965546\n",
      "2/10, steps:1401, loss:0.010827876627445221\n",
      "2/10, steps:1501, loss:0.0012373104691505432\n",
      "2/10, steps:1601, loss:0.03878188133239746\n",
      "2/10, steps:1701, loss:0.0015051215887069702\n",
      "2/10, steps:1801, loss:0.19163677096366882\n",
      "Average Accuracy is 0.9809695512820513\n",
      "3/10, steps:1, loss:0.12056343257427216\n",
      "3/10, steps:101, loss:0.001156136393547058\n",
      "3/10, steps:201, loss:0.004483334720134735\n",
      "3/10, steps:301, loss:0.021360646933317184\n",
      "3/10, steps:401, loss:0.003936804831027985\n",
      "3/10, steps:501, loss:0.006191704422235489\n",
      "3/10, steps:601, loss:0.07782182842493057\n",
      "3/10, steps:701, loss:0.009243220090866089\n",
      "3/10, steps:801, loss:0.010823681950569153\n",
      "3/10, steps:901, loss:0.05335962772369385\n",
      "3/10, steps:1001, loss:0.013054192066192627\n",
      "3/10, steps:1101, loss:0.10637931525707245\n",
      "3/10, steps:1201, loss:0.021973397582769394\n",
      "3/10, steps:1301, loss:0.08236199617385864\n",
      "3/10, steps:1401, loss:0.18045903742313385\n",
      "3/10, steps:1501, loss:0.025909535586833954\n",
      "3/10, steps:1601, loss:0.003286115825176239\n",
      "3/10, steps:1701, loss:0.030399613082408905\n",
      "3/10, steps:1801, loss:0.009306885302066803\n",
      "Average Accuracy is 0.9868790064102564\n",
      "4/10, steps:1, loss:0.009686559438705444\n",
      "4/10, steps:101, loss:0.002984777092933655\n",
      "4/10, steps:201, loss:0.011091306805610657\n",
      "4/10, steps:301, loss:0.13606426119804382\n",
      "4/10, steps:401, loss:0.09724206477403641\n",
      "4/10, steps:501, loss:0.007687561213970184\n",
      "4/10, steps:601, loss:0.004973292350769043\n",
      "4/10, steps:701, loss:0.04421800374984741\n",
      "4/10, steps:801, loss:0.01848164200782776\n",
      "4/10, steps:901, loss:0.00937969982624054\n",
      "4/10, steps:1001, loss:0.1941588670015335\n",
      "4/10, steps:1101, loss:0.0287005752325058\n",
      "4/10, steps:1201, loss:0.038601599633693695\n",
      "4/10, steps:1301, loss:0.0012191683053970337\n",
      "4/10, steps:1401, loss:0.13952608406543732\n",
      "4/10, steps:1501, loss:0.0037347376346588135\n",
      "4/10, steps:1601, loss:0.0029834285378456116\n",
      "4/10, steps:1701, loss:0.011380180716514587\n",
      "4/10, steps:1801, loss:0.0026053637266159058\n",
      "Average Accuracy is 0.9863782051282052\n",
      "5/10, steps:1, loss:0.011860810220241547\n",
      "5/10, steps:101, loss:0.0016858279705047607\n",
      "5/10, steps:201, loss:0.07329708337783813\n",
      "5/10, steps:301, loss:0.0030578821897506714\n",
      "5/10, steps:401, loss:0.08032140135765076\n",
      "5/10, steps:501, loss:0.015193246304988861\n",
      "5/10, steps:601, loss:0.0015309154987335205\n",
      "5/10, steps:701, loss:0.0669763907790184\n",
      "5/10, steps:801, loss:0.08088316023349762\n",
      "5/10, steps:901, loss:0.009713061153888702\n",
      "5/10, steps:1001, loss:0.0026463717222213745\n",
      "5/10, steps:1101, loss:0.009301483631134033\n",
      "5/10, steps:1201, loss:0.14427758753299713\n",
      "5/10, steps:1301, loss:0.007976964116096497\n",
      "5/10, steps:1401, loss:0.0054595619440078735\n",
      "5/10, steps:1501, loss:0.0026947855949401855\n",
      "5/10, steps:1601, loss:0.005468010902404785\n",
      "5/10, steps:1701, loss:0.05163765698671341\n",
      "5/10, steps:1801, loss:0.0768515020608902\n",
      "Average Accuracy is 0.9859775641025641\n",
      "6/10, steps:1, loss:0.00042782723903656006\n",
      "6/10, steps:101, loss:0.002573952078819275\n",
      "6/10, steps:201, loss:0.03102801740169525\n",
      "6/10, steps:301, loss:0.03891141712665558\n",
      "6/10, steps:401, loss:0.0023501068353652954\n",
      "6/10, steps:501, loss:0.009381204843521118\n",
      "6/10, steps:601, loss:0.01727242022752762\n",
      "6/10, steps:701, loss:0.018089167773723602\n",
      "6/10, steps:801, loss:0.0020205527544021606\n",
      "6/10, steps:901, loss:0.143434539437294\n",
      "6/10, steps:1001, loss:0.09830010682344437\n",
      "6/10, steps:1101, loss:0.002021089196205139\n",
      "6/10, steps:1201, loss:0.007840186357498169\n",
      "6/10, steps:1301, loss:0.15081346035003662\n",
      "6/10, steps:1401, loss:0.007431128993630409\n",
      "6/10, steps:1501, loss:0.20055168867111206\n",
      "6/10, steps:1601, loss:0.056012559682130814\n",
      "6/10, steps:1701, loss:0.16107991337776184\n",
      "6/10, steps:1801, loss:0.007314965128898621\n",
      "Average Accuracy is 0.9858774038461539\n",
      "7/10, steps:1, loss:0.0496373325586319\n",
      "7/10, steps:101, loss:0.0035120248794555664\n",
      "7/10, steps:201, loss:0.001310795545578003\n",
      "7/10, steps:301, loss:0.00047269463539123535\n",
      "7/10, steps:401, loss:0.11643655598163605\n",
      "7/10, steps:501, loss:0.010200954973697662\n",
      "7/10, steps:601, loss:0.014061525464057922\n",
      "7/10, steps:701, loss:0.015325464308261871\n",
      "7/10, steps:801, loss:0.011148393154144287\n",
      "7/10, steps:901, loss:0.0014718994498252869\n",
      "7/10, steps:1001, loss:0.0241160336881876\n",
      "7/10, steps:1101, loss:0.0010789558291435242\n",
      "7/10, steps:1201, loss:0.0015253052115440369\n",
      "7/10, steps:1301, loss:0.00369475781917572\n",
      "7/10, steps:1401, loss:0.004648566246032715\n",
      "7/10, steps:1501, loss:0.002583138644695282\n",
      "7/10, steps:1601, loss:0.014193188399076462\n",
      "7/10, steps:1701, loss:0.0002834200859069824\n",
      "7/10, steps:1801, loss:0.0010428130626678467\n",
      "Average Accuracy is 0.987479967948718\n",
      "8/10, steps:1, loss:0.0009018927812576294\n",
      "8/10, steps:101, loss:0.0002222806215286255\n",
      "8/10, steps:201, loss:0.011049076914787292\n",
      "8/10, steps:301, loss:0.0015708506107330322\n",
      "8/10, steps:401, loss:0.03226374089717865\n",
      "8/10, steps:501, loss:0.0016654133796691895\n",
      "8/10, steps:601, loss:0.0007573366165161133\n",
      "8/10, steps:701, loss:0.004894644021987915\n",
      "8/10, steps:801, loss:0.004497632384300232\n",
      "8/10, steps:901, loss:0.013633817434310913\n",
      "8/10, steps:1001, loss:0.007155656814575195\n",
      "8/10, steps:1101, loss:0.0003268122673034668\n",
      "8/10, steps:1201, loss:0.01945830136537552\n",
      "8/10, steps:1301, loss:0.013361334800720215\n",
      "8/10, steps:1401, loss:0.005097508430480957\n",
      "8/10, steps:1501, loss:0.038638778030872345\n",
      "8/10, steps:1601, loss:0.0006249919533729553\n",
      "8/10, steps:1701, loss:0.10886749625205994\n",
      "8/10, steps:1801, loss:0.0015304088592529297\n",
      "Average Accuracy is 0.9848758012820513\n",
      "9/10, steps:1, loss:0.008891571313142776\n",
      "9/10, steps:101, loss:0.001825764775276184\n",
      "9/10, steps:201, loss:0.11651931703090668\n",
      "9/10, steps:301, loss:0.018866300582885742\n",
      "9/10, steps:401, loss:0.0066434964537620544\n",
      "9/10, steps:501, loss:0.01165221631526947\n",
      "9/10, steps:601, loss:0.008111178874969482\n",
      "9/10, steps:701, loss:0.0027614086866378784\n",
      "9/10, steps:801, loss:0.03085269033908844\n",
      "9/10, steps:901, loss:0.0003725886344909668\n",
      "9/10, steps:1001, loss:0.0021242350339889526\n",
      "9/10, steps:1101, loss:0.0031332969665527344\n",
      "9/10, steps:1201, loss:0.0031114965677261353\n",
      "9/10, steps:1301, loss:0.009705603122711182\n",
      "9/10, steps:1401, loss:0.001462206244468689\n",
      "9/10, steps:1501, loss:0.0008941143751144409\n",
      "9/10, steps:1601, loss:0.017257481813430786\n",
      "9/10, steps:1701, loss:6.0170888900756836e-05\n",
      "9/10, steps:1801, loss:0.04438307136297226\n",
      "Average Accuracy is 0.9873798076923077\n",
      "10/10, steps:1, loss:0.019429191946983337\n",
      "10/10, steps:101, loss:0.0001784190535545349\n",
      "10/10, steps:201, loss:0.00773581862449646\n",
      "10/10, steps:301, loss:0.0026732832193374634\n",
      "10/10, steps:401, loss:0.03586462140083313\n",
      "10/10, steps:501, loss:0.011358663439750671\n",
      "10/10, steps:601, loss:0.0001553744077682495\n",
      "10/10, steps:701, loss:3.622472286224365e-05\n",
      "10/10, steps:801, loss:0.007309824228286743\n",
      "10/10, steps:901, loss:4.011392593383789e-05\n",
      "10/10, steps:1001, loss:0.003514360636472702\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10/10, steps:1101, loss:0.004707172513008118\n",
      "10/10, steps:1201, loss:0.0017499513924121857\n",
      "10/10, steps:1301, loss:0.011861233040690422\n",
      "10/10, steps:1401, loss:0.04933084547519684\n",
      "10/10, steps:1501, loss:0.007391378283500671\n",
      "10/10, steps:1601, loss:0.010394856333732605\n",
      "10/10, steps:1701, loss:0.007220812141895294\n",
      "10/10, steps:1801, loss:0.053185343742370605\n",
      "Average Accuracy is 0.987479967948718\n"
     ]
    }
   ],
   "source": [
    "epochs = 10\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "model = model.to(device)\n",
    "for i in range(epochs):\n",
    "    # Re-enter training mode at the start of every epoch: the evaluation pass\n",
    "    # below switches the model to eval mode, which would otherwise leave\n",
    "    # BatchNorm using frozen running statistics for all later epochs.\n",
    "    model.train()\n",
    "    for j, (img, tag) in enumerate(trainloader):\n",
    "        img = img.to(device)\n",
    "        # CrossEntropyLoss needs int64 class indices. Cast on whichever device\n",
    "        # is active; a CUDA-specific tensor type would fail on the CPU fallback.\n",
    "        tag = tag.to(device=device, dtype=torch.long)\n",
    "        output = model(img)\n",
    "        loss = criterion(output, tag)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Log the loss of the current batch every 100 steps.\n",
    "        if j % 100 == 0:\n",
    "            print('{}/{}, steps:{}, loss:{}'.format(i+1, epochs, j+1, loss.item()))\n",
    "\n",
    "    # Measure accuracy on the test loader after each epoch, without gradients.\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        correct = 0\n",
    "        total = 0\n",
    "        for img, tag in testloader:\n",
    "            img = img.to(device)\n",
    "            tag = tag.to(device=device, dtype=torch.long)\n",
    "            output = model(img)\n",
    "            # The predicted class is the index of the largest score per sample.\n",
    "            predicted = output.argmax(dim=1)\n",
    "            correct += (predicted == tag).sum().item()\n",
    "            total += tag.shape[0]\n",
    "        print('Average Accuracy is {}'.format(correct / total))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 保存模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist only the model's state_dict (not the module object itself).\n",
    "torch.save(obj=model.state_dict(), f='mnist.ckpt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 加载权重测试\n",
    "查看测试结果与训练过程中的测试结果是否保持一致"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Accuracy is 0.9875801282051282\n"
     ]
    }
   ],
   "source": [
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "model = SimpleModel(10)\n",
    "# map_location remaps the stored tensors onto the current device, so a\n",
    "# checkpoint written from a GPU session still loads on a CPU-only machine.\n",
    "model.load_state_dict(torch.load('mnist.ckpt', map_location=device))\n",
    "model = model.to(device)\n",
    "# Eval mode makes BatchNorm use its stored running statistics.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    for img, tag in testloader:\n",
    "        img = img.to(device)\n",
    "        # Cast labels to int64 on the active device; a CUDA-specific tensor\n",
    "        # type would make this cell fail without a GPU.\n",
    "        tag = tag.to(device=device, dtype=torch.long)\n",
    "        output = model(img)\n",
    "        # The predicted class is the index of the largest score per sample.\n",
    "        predicted = output.argmax(dim=1)\n",
    "        correct += (predicted == tag).sum().item()\n",
    "        total += tag.shape[0]\n",
    "    print('Average Accuracy is {}'.format(correct / total))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
